# Import and process data
# Load each model's responses twice: as returned by the model (icd = FALSE)
# and after ICD conversion (icd = TRUE). read_model() is a project helper not
# shown here (presumably from the sourced utils/ scripts).
df_gpt3.5 <- read_model("gpt-3.5-turbo-1106", icd = FALSE)
df_gpt4.0 <- read_model("gpt-4-turbo-preview", icd = FALSE)
df_claude3_haiku_t1.0 <- read_model("claude-3-haiku-20240307_t1-0", icd = FALSE)
df_claude3_opus_t1.0 <- read_model("claude-3-opus-20240229_t1-0", icd = FALSE)
df_gemini1.0_pro_t1.0 <- read_model("gemini-1.0-pro-002_t1-0", icd = FALSE)
df_gemini1.5_pro_t1.0 <- read_model("gemini-1.5-pro-001_t1-0", icd = FALSE)
# ICD-converted responses for the same six models.
df_gpt3.5_icd <- read_model("gpt-3.5-turbo-1106", icd = TRUE)
df_gpt4.0_icd <- read_model("gpt-4-turbo-preview", icd = TRUE)
df_claude3_haiku_t1.0_icd <- read_model("claude-3-haiku-20240307_t1-0", icd = TRUE)
df_claude3_opus_t1.0_icd <- read_model("claude-3-opus-20240229_t1-0", icd = TRUE)
df_gemini1.0_pro_t1.0_icd <- read_model("gemini-1.0-pro-002_t1-0", icd = TRUE)
df_gemini1.5_pro_t1.0_icd <- read_model("gemini-1.5-pro-001_t1-0", icd = TRUE)
# Tally each pair of co-occurring diagnoses within each criteria.
# remove_singletons is passed to create_codiagnosis_df() (project helper, not
# shown); TRUE is spelled out because `T` is an ordinary, reassignable variable.
# Original responses
df_gpt3.5_codiag <- create_codiagnosis_df(df_gpt3.5, remove_singletons = TRUE)
df_gpt4.0_codiag <- create_codiagnosis_df(df_gpt4.0, remove_singletons = TRUE)
df_claude3_haiku_t1.0_codiag <- create_codiagnosis_df(df_claude3_haiku_t1.0, remove_singletons = TRUE)
df_claude3_opus_t1.0_codiag <- create_codiagnosis_df(df_claude3_opus_t1.0, remove_singletons = TRUE)
df_gemini1.0_pro_t1.0_codiag <- create_codiagnosis_df(df_gemini1.0_pro_t1.0, remove_singletons = TRUE)
df_gemini1.5_pro_t1.0_codiag <- create_codiagnosis_df(df_gemini1.5_pro_t1.0, remove_singletons = TRUE)
# ICD converted responses
df_gpt3.5_icd_codiag <- create_codiagnosis_df(df_gpt3.5_icd, remove_singletons = TRUE)
df_gpt4.0_icd_codiag <- create_codiagnosis_df(df_gpt4.0_icd, remove_singletons = TRUE)
df_claude3_haiku_t1.0_icd_codiag <- create_codiagnosis_df(df_claude3_haiku_t1.0_icd, remove_singletons = TRUE)
df_claude3_opus_t1.0_icd_codiag <- create_codiagnosis_df(df_claude3_opus_t1.0_icd, remove_singletons = TRUE)
df_gemini1.0_pro_t1.0_icd_codiag <- create_codiagnosis_df(df_gemini1.0_pro_t1.0_icd, remove_singletons = TRUE)
df_gemini1.5_pro_t1.0_icd_codiag <- create_codiagnosis_df(df_gemini1.5_pro_t1.0_icd, remove_singletons = TRUE)
# Preview one result.
df_gpt4.0_codiag
# Plot centrality
# Compute node centrality separately within each criteria's sub-graph.
#
# g: a tidygraph tbl_graph whose edges carry a `criteria` column.
# centrality_fun: name of (or a function object for) a tidygraph centrality
#   function, e.g. "centrality_eigen". It is called with no arguments inside a
#   node-context mutate(), so it must work with its defaults.
# Returns a data.frame with one row per criteria (used as row names) and one
# column per node name. A diagnosis absent from a criteria's sub-graph gets 0.
calculate_subgraph_centrality <- function(g, centrality_fun = "centrality_eigen") {
  # match.fun() resolves a function name and also accepts a function directly.
  centrality_fn <- match.fun(centrality_fun)
  criteria_levels <- g %>% activate(edges) %>% pull(criteria) %>% unique()

  data.frame(criteria = criteria_levels) %>%
    mutate(sub_graphs = map(criteria, function(crit) {
      # Qualified explicitly: tibble also exports an as_data_frame() that would
      # return the active tbl_graph table instead of the named edge list.
      igraph::as_data_frame(g, what = "edges") %>%
        filter(criteria == crit) %>%
        as_tbl_graph(directed = FALSE) %>%
        activate(nodes) %>%
        mutate(centrality = centrality_fn()) %>%
        data.frame()
    })) %>%
    unnest(sub_graphs) %>%
    pivot_wider(names_from = "name", values_from = "centrality", values_fill = 0) %>%
    column_to_rownames("criteria")
}
# Worked example for one model: per-criteria centrality of the GPT-4 graph built
# from the original (non-ICD) responses. top_n is expected to already exist in
# the session (it is not defined in this part of the notebook).
graph_all_gpt4 <- make_codiagnosis_graph(df_gpt4.0_codiag, n_diagnoses = top_n)
df_centr_gpt4 <- calculate_subgraph_centrality(graph_all_gpt4)
# Cosine similarity between criteria, based on their node-centrality profiles.
#
# data: criteria x diagnosis data.frame with criteria as row names, e.g. the
#   output of calculate_subgraph_centrality().
# Returns a long data.frame with columns V1, V2 (criteria labels after
# format_criteria()) and cosine.
centrality_similarity <- function(data) {
  # format_criteria() operates on a column, so move row names out and back.
  labelled <- data %>%
    rownames_to_column("criteria") %>%
    format_criteria() %>%
    column_to_rownames("criteria")

  # lsa::cosine() compares columns, so criteria must become the columns.
  sim_matrix <- lsa::cosine(t(as.matrix(labelled)))

  sim_matrix %>%
    as.data.frame() %>%
    rownames_to_column("V1") %>%
    pivot_longer(-V1, names_to = "V2", values_to = "cosine")
}
centrality_similarity(df_centr_gpt4)
# One-model pipeline: co-diagnosis graph -> per-criteria centrality ->
# pairwise criteria cosine similarity (long format).
# n_diag is forwarded to make_codiagnosis_graph() as n_diagnoses.
centrality_wrapper <- function(data, n_diag = NULL) {
  graph <- make_codiagnosis_graph(data, n_diagnoses = n_diag)
  centrality <- calculate_subgraph_centrality(graph)
  centrality_similarity(centrality)
}
# Criteria-by-criteria cosine similarity averaged over the six models'
# ICD-converted data, reshaped into a square matrix (rows V1, columns V2).
# listN() is a project helper; presumably it names list elements after the
# argument expressions, which enframe() turns into the `model` column.
average_cosine_matrix <- listN(
df_gpt3.5_icd_codiag,
df_gpt4.0_icd_codiag,
df_claude3_haiku_t1.0_icd_codiag,
df_claude3_opus_t1.0_icd_codiag,
df_gemini1.0_pro_t1.0_icd_codiag,
df_gemini1.5_pro_t1.0_icd_codiag
) %>%
enframe(name = "model", value = "data") %>%
# One long V1/V2/cosine table per model.
mutate(data = map(data, centrality_wrapper)) %>%
unnest(data) %>%
# Mean over models for each criteria pair.
dplyr::summarise(cosine = mean(cosine), .by = c("V1", "V2")) %>%
pivot_wider(names_from = "V2", values_from = "cosine") %>%
column_to_rownames("V1") %>%
as.matrix()
average_cosine_matrix
SLE - SLICC SLE - EULAR-ACR MCAS - Consortium MCAS - Alternative
SLE - SLICC 1.0000000 0.8540815 0.2276693 0.5103618
SLE - EULAR-ACR 0.8540815 1.0000000 0.2682776 0.4819197
MCAS - Consortium 0.2276693 0.2682776 1.0000000 0.5459547
MCAS - Alternative 0.5103618 0.4819197 0.5459547 1.0000000
# Draw a ComplexHeatmap of a numeric matrix with a three-point colour ramp.
#
# Arguments:
#   data               numeric matrix to plot.
#   plot_title         optional title drawn above the heatmap columns.
#   legend_title       optional legend title (passed to Heatmap() as `name`).
#   color_scale        three colours for (min, midpoint, max); defaults to
#                      hcl.colors(3, "Earth") if symmetric, else viridis(3).
#   midpoint           value mapped to the middle colour; defaults to the
#                      centre of the colour range.
#   symmetric          TRUE: range is [-max(|data|), max(|data|)];
#                      FALSE: range is [min(data), max(data)].
#   matrix_title_size  font size of plot_title.
#   matrix_names_size  font size of row/column names.
#   legend_title_size, legend_label_size  legend font sizes.
#   dendrograms        show row/column dendrograms sized by dendrogram_weight.
#   legend_orientation "vertical" (legend on the right) or "horizontal" (bottom).
#   legend_length      optional legend height (vertical) or width (horizontal).
#   grid_lines         draw a black border around every cell.
# Returns the result of draw(), which also renders the heatmap.
custom_heatmap <- function(data,
                           plot_title = NULL,
                           legend_title = NULL,
                           color_scale = NULL,
                           midpoint = NULL,
                           symmetric = TRUE,
                           matrix_title_size = 10,
                           matrix_names_size = 8,
                           legend_title_size = 10,
                           legend_label_size = 8,
                           dendrograms = TRUE,
                           dendrogram_weight = unit(10, "mm"),
                           legend_orientation = "vertical",
                           legend_length = NULL,
                           grid_lines = FALSE) {
  # Determine the colour range (scalar if/else rather than ifelse()).
  if (symmetric) {
    scale_max <- max(abs(data))
    scale_min <- -scale_max
  } else {
    scale_max <- max(data)
    scale_min <- min(data)
  }
  if (is.null(midpoint)) {
    midpoint <- scale_min + (scale_max - scale_min) / 2
  }

  # Default colour scales depend on whether the range is symmetric around 0.
  if (is.null(color_scale)) {
    color_scale <- if (symmetric) hcl.colors(3, "Earth") else viridis::viridis(3)
  }
  color_function <- circlize::colorRamp2(c(scale_min, midpoint, scale_max), color_scale)

  # Legend parameters
  legend_params <- list(
    "title_gp" = grid::gpar(fontsize = legend_title_size, fontface = "bold"),
    "labels_gp" = grid::gpar(fontsize = legend_label_size),
    "direction" = legend_orientation
  )

  # Heatmap parameters
  heatmap_arguments <- list(
    "matrix" = data,
    "col" = color_function,
    "row_names_gp" = grid::gpar(fontsize = matrix_names_size),
    "column_names_gp" = grid::gpar(fontsize = matrix_names_size),
    "show_column_dend" = dendrograms,
    "show_row_dend" = dendrograms
  )
  if (!is.null(plot_title)) {
    # plot_title / matrix_title_size were previously accepted but ignored.
    heatmap_arguments[["column_title"]] <- plot_title
    heatmap_arguments[["column_title_gp"]] <- grid::gpar(fontsize = matrix_title_size)
  }
  if (!is.null(legend_title)) {
    heatmap_arguments[["name"]] <- legend_title
  }
  if (grid_lines) {
    heatmap_arguments[["rect_gp"]] <- grid::gpar(col = "black", lwd = 1)
  }
  if (dendrograms) {
    heatmap_arguments[["column_dend_height"]] <- dendrogram_weight
    heatmap_arguments[["row_dend_width"]] <- dendrogram_weight
  }

  # Legend placement and length follow its orientation.
  legend_side <- if (legend_orientation == "vertical") "right" else "bottom"
  if (!is.null(legend_length)) {
    if (legend_orientation == "vertical") {
      legend_params[["legend_height"]] <- legend_length
    } else if (legend_orientation == "horizontal") {
      legend_params[["legend_width"]] <- legend_length
    }
  }

  # Function call
  heatmap_arguments[["heatmap_legend_param"]] <- legend_params
  ht <- do.call(Heatmap, heatmap_arguments)
  draw(ht, heatmap_legend_side = legend_side, align_heatmap_legend = "global_center")
}
custom_heatmap(average_cosine_matrix, symmetric = FALSE, legend_title = "Cosine similarity", grid_lines = TRUE, dendrograms = FALSE, legend_orientation = "horizontal", legend_length = unit(10, "cm"))

# Per-model (not averaged) criteria-pair cosine similarities for the
# ICD-converted data: one row per model x (V1, V2) pair. This repeats the
# centrality_wrapper() computation used for average_cosine_matrix.
df_comp <- listN(
df_gpt3.5_icd_codiag,
df_gpt4.0_icd_codiag,
df_claude3_haiku_t1.0_icd_codiag,
df_claude3_opus_t1.0_icd_codiag,
df_gemini1.0_pro_t1.0_icd_codiag,
df_gemini1.5_pro_t1.0_icd_codiag
) %>%
enframe(name = "model", value = "data") %>%
mutate(data = map(data, centrality_wrapper)) %>%
unnest(data)
# Plot, per model, the cosine similarity of the two SLE criteria and of the two
# MCAS criteria, with a p-value label from ggpubr::stat_compare_means()
# (wilcox.test, grouped by comparison).
#
# df: long data.frame with columns model, V1, V2, cosine (e.g. df_comp).
# pt_size: point size; p_size: text size of the p-value label.
# Returns a ggplot object.
# (Stray knitr warning output that had been pasted into this body was removed;
# it made the function unparseable.)
cosine_similarity_compare <- function(df, pt_size = 1, p_size = 3) {
  # Keep one ordering of each within-disease criteria pair, then turn "A_B"
  # into a two-line "A\nvs. B" axis label.
  df <- df %>%
    unite(comp, V1, V2) %>%
    filter(comp %in% c("MCAS - Consortium_MCAS - Alternative", "SLE - EULAR-ACR_SLE - SLICC")) %>%
    mutate(comp = gsub("_", "\nvs. ", comp)) %>%
    format_models()

  # Plot data
  df %>%
    ggplot(aes(x = comp, y = cosine, color = model)) +
    geom_point(size = pt_size, position = position_dodge(width = 0.75)) +
    theme_bw() +
    ggpubr::stat_compare_means(aes(group = comp), method = "wilcox.test", label = "p", vjust = 0.75, show.legend = FALSE, size = p_size) +
    theme(axis.text.x = element_text(angle = 90, hjust = 1)) +
    labs(x = NULL, y = "Cosine similarity", color = "") +
    scale_color_brewer(palette = "Dark2")
}
cosine_similarity_compare(df_comp)

# Cosine similarity between every (model, criteria) combination, using each
# combination's vector of per-diagnosis centralities. Diagnoses absent from a
# combination's sub-graph count as 0.
individual_cosine_matrix <- combine_data_frames(
  df_gpt3.5_icd_codiag,
  df_gpt4.0_icd_codiag,
  df_claude3_haiku_t1.0_icd_codiag,
  df_claude3_opus_t1.0_icd_codiag,
  df_gemini1.0_pro_t1.0_icd_codiag,
  df_gemini1.5_pro_t1.0_icd_codiag,
  additional_function = function(x) {
    x %>%
      make_codiagnosis_graph() %>%
      calculate_subgraph_centrality() %>%
      rownames_to_column("criteria") %>%
      pivot_longer(-criteria, names_to = "diagnosis", values_to = "centrality")
  }
) %>%
  # Row labels become "<source data frame>-<criteria>" (original_df is
  # presumably added by combine_data_frames(); not visible here).
  unite(criteria, original_df, criteria, sep = "-") %>%
  pivot_wider(names_from = "diagnosis", values_from = "centrality", values_fill = 0) %>%
  column_to_rownames("criteria") %>%
  as.matrix() %>%
  t() %>%
  # lsa::cosine() already returns a square matrix labelled by the input
  # columns, so the former long/wide round-trip back to a matrix was a no-op.
  lsa::cosine()
individual_cosine_matrix[1:5, 1:5]
df_gpt3.5_icd_codiag-slicc_sle df_gpt3.5_icd_codiag-eular_acr_sle df_gpt3.5_icd_codiag-mcas_consortium df_gpt3.5_icd_codiag-mcas_alternative
df_gpt3.5_icd_codiag-slicc_sle 1.0000000 0.8647922 0.2136086 0.4720033
df_gpt3.5_icd_codiag-eular_acr_sle 0.8647922 1.0000000 0.2643814 0.4613926
df_gpt3.5_icd_codiag-mcas_consortium 0.2136086 0.2643814 1.0000000 0.5519625
df_gpt3.5_icd_codiag-mcas_alternative 0.4720033 0.4613926 0.5519625 1.0000000
df_gpt4.0_icd_codiag-slicc_sle 0.8522184 0.7324445 0.1663721 0.3913603
df_gpt4.0_icd_codiag-slicc_sle
df_gpt3.5_icd_codiag-slicc_sle 0.8522184
df_gpt3.5_icd_codiag-eular_acr_sle 0.7324445
df_gpt3.5_icd_codiag-mcas_consortium 0.1663721
df_gpt3.5_icd_codiag-mcas_alternative 0.3913603
df_gpt4.0_icd_codiag-slicc_sle 1.0000000
# Heatmap of cosine similarity between every (model, criteria) combination.
alt_heatmap <- model_criteria_heatmap(individual_cosine_matrix,
  color_scale = viridis::viridis(3),
  title = " ",
  metric = "Cosine\nsimilarity",
  symmetric = FALSE,
  font_size = 8)
alt_heatmap

# Capture the drawn heatmap as a grob inside a cowplot canvas so ggsave() can write it.
plt_alt <- cowplot::plot_grid(grid::grid.grabExpr(ComplexHeatmap::draw(alt_heatmap)))
plt_alt

ggsave(here("figures/4_alt_heatmap.pdf"), plot = plt_alt, height = 4, width = 6)
# Final plot
# Version 1
# n_diagnoses_bar <- 10
# n_diagnoses_abundance <- 50
# n_diagnoses_cumulative <- 50
# Shared text sizes and legend padding (pt) for the composite figure panels.
title_size <- 9
label_size <- 6
legend_x_pad <- 0
legend_y_pad <- 0
# Theme tweaks wrapped in a list so they can be added to a ggplot with `+`.
apply_text_formatting <- list(
  theme(
    axis.text = element_text(size = label_size),
    axis.title = element_text(size = title_size),
    legend.text = element_text(size = label_size, margin = margin(0, 0, 0, r = 1)),
    strip.text = element_text(size = label_size + 1),
    legend.key.height = unit(0.4, "cm"),
    legend.key.width = unit(0.4, "cm"),
    # legend.key = element_rect(size = margin(0,0,0,0)),
    # `linewidth` replaces element_rect()'s `size`, deprecated since ggplot2
    # 3.4.0 (legend.key.spacing.y below already requires ggplot2 >= 3.5.0).
    legend.box.background = element_rect(color = "black", linewidth = 1),
    legend.margin = margin(
      t = legend_y_pad,
      r = legend_x_pad + 2,
      b = legend_y_pad,
      l = legend_x_pad
    ),
    # Negative spacing pulls legend keys closer together vertically.
    legend.key.spacing.y = unit(-1.5, "pt"),
    legend.box.spacing = unit(1.0, "pt"),
    axis.text.x = element_text(angle = 45, hjust = 1)
  )
)
# Tight facet-strip text margins.
strip_margin <- 1
strip_formatting <- list(theme(
  strip.text.x = element_text(margin = margin(t = strip_margin, r = strip_margin, b = strip_margin, l = strip_margin)),
  strip.text.y = element_text(margin = margin(t = strip_margin, r = strip_margin, b = strip_margin, l = strip_margin))
  # strip.background = element_rect(margin = margin(t = strip_margin, r = strip_margin, b = strip_margin, l = strip_margin))
))
# Panel A: combined co-diagnosis network across the six models
# (threshold_method = "average", "stress" layout), seeded for reproducibility.
# NOTE(review): the object is named *_icd but is built from the original
# (non-ICD) *_codiag data frames; confirm this is intended.
set.seed(1234)
plt_graph_icd <-
multi_make_codiagnosis_graph(
threshold_method = "average",
top_n = 100,
layout = "stress",
df_gpt3.5_codiag,
df_gpt4.0_codiag,
df_claude3_haiku_t1.0_codiag,
df_claude3_opus_t1.0_codiag,
df_gemini1.0_pro_t1.0_codiag,
df_gemini1.5_pro_t1.0_codiag,
point_size = 1.25,
border_size = 0.25,
edge_width = 0.5,
edge_alpha = 0.5,
label_text_size = 9,
tick_text_size = 6,
highlight_stroke_multiplier = 3,
legend_height = unit(25, "pt"),
legend_width = unit(10, "pt")
)
# Panel B: edge density of the ICD-converted networks, restyled for the
# composite figure with a bottom, two-column legend.
plt_edge_icd <- multi_edge_density_plot(
df_gpt3.5_icd_codiag,
df_gpt4.0_icd_codiag,
df_claude3_haiku_t1.0_icd_codiag,
df_claude3_opus_t1.0_icd_codiag,
df_gemini1.0_pro_t1.0_icd_codiag,
df_gemini1.5_pro_t1.0_icd_codiag
) +
apply_text_formatting +
theme(legend.position = "bottom", legend.direction = "horizontal") +
guides(color = guide_legend(override.aes = list(size = 1), ncol = 2))+
scale_y_continuous(expand = expansion(mult = c(0.1,0.1)))
Scale for colour is already present.
Adding another scale for colour, which will replace the existing scale.
# Panel C (Version 1): heatmap of the model-averaged cosine similarity matrix.
# custom_heatmap() calls draw(), so this assignment also renders the heatmap.
plt_heatmap_icd <-
  custom_heatmap(
    average_cosine_matrix,
    symmetric = FALSE,
    legend_title = "Cosine similarity",
    grid_lines = TRUE,
    dendrograms = TRUE,
    legend_orientation = "horizontal",
    legend_length = unit(2, "cm"),
    matrix_names_size = 6,
    legend_title_size = 7.5,
    legend_label_size = 6,
    dendrogram_weight = unit(2, "mm")
  )


# Version 1 layout: row A = network graph (offset by a spacer); the row below
# holds the edge-density plot (B) beside the heatmap (C). NULL = empty spacer.
plt_fig <- cowplot::plot_grid(
  NULL,
  cowplot::plot_grid(
    NULL, plt_graph_icd, nrow = 1, rel_widths = c(0.05, 0.9)
  ),
  NULL,
  cowplot::plot_grid(
    cowplot::plot_grid(NULL, plt_edge_icd, rel_heights = c(0.1, 1), ncol = 1),
    NULL,
    cowplot::plot_grid(
      NULL,
      grid::grid.grabExpr(ComplexHeatmap::draw(plt_heatmap_icd, heatmap_legend_side = "bottom")),
      NULL,
      ncol = 1,
      rel_heights = c(0.1, 1, 0.1)
    ),
    nrow = 1,
    rel_widths = c(0.5, 0.1, 0.6),
    labels = c("B", "", "C"),
    # `align` takes "h"/"v"/"hv" and `axis` a combination of "l","r","t","b".
    # They were swapped (align = 'bt'), which cowplot does not recognise as an
    # alignment mode, so no alignment was applied.
    align = "h", axis = "bt"
  ),
  ncol = 1,
  rel_heights = c(0.05, 0.95, 0.01, 1),
  labels = c("A")
)
plt_fig

# Version 2
# n_diagnoses_bar <- 10
# n_diagnoses_abundance <- 50
# n_diagnoses_cumulative <- 50
# Shared text sizes and legend padding (pt) for the composite figure panels.
title_size <- 9
label_size <- 6
legend_x_pad <- 0
legend_y_pad <- 0
# Theme tweaks wrapped in a list so they can be added to a ggplot with `+`.
# Version 2 uses vertical (90 degree) x-axis labels.
apply_text_formatting <- list(
  theme(
    axis.text = element_text(size = label_size),
    axis.title = element_text(size = title_size),
    legend.text = element_text(size = label_size, margin = margin(0, 0, 0, r = 1)),
    strip.text = element_text(size = label_size + 1),
    legend.key.height = unit(0.4, "cm"),
    legend.key.width = unit(0.4, "cm"),
    # legend.key = element_rect(size = margin(0,0,0,0)),
    # `linewidth` replaces element_rect()'s `size`, deprecated since ggplot2
    # 3.4.0 (legend.key.spacing.y below already requires ggplot2 >= 3.5.0).
    legend.box.background = element_rect(color = "black", linewidth = 1),
    legend.margin = margin(
      t = legend_y_pad,
      r = legend_x_pad + 2,
      b = legend_y_pad,
      l = legend_x_pad
    ),
    # Negative spacing pulls legend keys closer together vertically.
    legend.key.spacing.y = unit(-1.5, "pt"),
    legend.box.spacing = unit(1.0, "pt"),
    axis.text.x = element_text(angle = 90, vjust = 0.5)
  )
)
# Tight facet-strip text margins.
strip_margin <- 1
strip_formatting <- list(theme(
  strip.text.x = element_text(margin = margin(t = strip_margin, r = strip_margin, b = strip_margin, l = strip_margin)),
  strip.text.y = element_text(margin = margin(t = strip_margin, r = strip_margin, b = strip_margin, l = strip_margin))
  # strip.background = element_rect(margin = margin(t = strip_margin, r = strip_margin, b = strip_margin, l = strip_margin))
))
# Panel A: combined co-diagnosis network across the six models
# (threshold_method = "average", "stress" layout), seeded for reproducibility.
# NOTE(review): the object is named *_icd but is built from the original
# (non-ICD) *_codiag data frames; confirm this is intended.
set.seed(1234)
plt_graph_icd <-
multi_make_codiagnosis_graph(
threshold_method = "average",
top_n = 100,
layout = "stress",
df_gpt3.5_codiag,
df_gpt4.0_codiag,
df_claude3_haiku_t1.0_codiag,
df_claude3_opus_t1.0_codiag,
df_gemini1.0_pro_t1.0_codiag,
df_gemini1.5_pro_t1.0_codiag,
point_size = 1.25,
border_size = 0.25,
edge_width = 0.5,
edge_alpha = 0.5,
label_text_size = 9,
tick_text_size = 6,
highlight_stroke_multiplier = 3,
legend_height = unit(25, "pt"),
legend_width = unit(10, "pt")
)
# Panel B: edge density of the ICD-converted networks. The legend is removed
# from this panel in the layout below and drawn separately via get_legend().
plt_edge_icd <- multi_edge_density_plot(
df_gpt3.5_icd_codiag,
df_gpt4.0_icd_codiag,
df_claude3_haiku_t1.0_icd_codiag,
df_claude3_opus_t1.0_icd_codiag,
df_gemini1.0_pro_t1.0_icd_codiag,
df_gemini1.5_pro_t1.0_icd_codiag
) +
apply_text_formatting +
# theme(legend.position = "bottom", legend.direction = "horizontal") +
guides(color = guide_legend(override.aes = list(size = 1), ncol = 2))+
scale_y_continuous(expand = expansion(mult = c(0.1,0.1)))
Scale for colour is already present.
Adding another scale for colour, which will replace the existing scale.
# Panel C: per-model cosine similarity for the within-disease criteria pairs.
plt_cosine <- cosine_similarity_compare(df_comp, p_size = 2) +
  apply_text_formatting
# Panel D: heatmap of the model-averaged cosine similarity matrix.
# custom_heatmap() calls draw(), so this assignment also renders the heatmap.
plt_heatmap_icd <-
  custom_heatmap(
    average_cosine_matrix,
    symmetric = FALSE,
    legend_title = "Average\ncosine similarity",
    grid_lines = TRUE,
    dendrograms = TRUE,
    legend_orientation = "horizontal",
    legend_length = unit(2.5, "cm"),
    matrix_names_size = 6,
    legend_title_size = 7.5,
    legend_label_size = 6,
    dendrogram_weight = unit(2.5, "mm")
  )


# NOTE(review): pd is not referenced in the visible code.
pd <- -1
# Version 2 layout (rows): 1 spacer; 2 network graph (A); 3 spacer;
# 4 edge-density (B) and cosine (C) panels over a shared legend, beside the
# averaged-cosine heatmap (D). Panel letters are added afterwards with
# draw_plot_label(). NULL is cowplot's empty placeholder.
plt_fig <- cowplot::plot_grid(
  # 1
  NULL,
  # 2
  cowplot::plot_grid(
    NULL, plt_graph_icd, nrow = 1, rel_widths = c(0.05, 0.9)
  ),
  # 3
  NULL,
  # 4
  cowplot::plot_grid(
    # 4.1
    cowplot::plot_grid(
      # 4.1.1: the two ggplot panels, legends stripped, axes aligned
      cowplot::plot_grid(
        # 4.1.1.1
        plt_edge_icd + theme(legend.position = "none"),
        # 4.1.1.2
        plt_cosine + theme(legend.position = "none") + ylim(0.4, 1),
        # rel_widths = c(10,9),
        nrow = 1,
        align = "h", axis = "bt"
      ),
      # 4.1.2: spacer / shared legend / spacer. These spacers were NA, which
      # cowplot cannot convert to a grob (it warned for each one).
      NULL,
      cowplot::plot_grid(NULL, cowplot::get_legend(plt_edge_icd), nrow = 1, rel_widths = c(0.15, 1)),
      NULL,
      ncol = 1,
      rel_heights = c(1, 0.04, 0.1, 0.1) # 4.1._
    ),
    # 4.2: heatmap with spacers above and below
    cowplot::plot_grid(
      # 4.2.1
      NULL,
      # 4.2.2
      grid::grid.grabExpr(ComplexHeatmap::draw(plt_heatmap_icd, heatmap_legend_side = "bottom", padding = unit(c(0, 0, 0, -1), "mm"))),
      # 4.2.3
      NULL,
      ncol = 1,
      rel_heights = c(0.1, 1, 0.1)
    ),
    nrow = 1,
    rel_widths = c(0.6, 0.5), # 4._
    align = "h", axis = "bt"
  ),
  ncol = 1,
  rel_heights = c(0.05, 0.95, 0.05, 1)
)
Warning: Multiple components found; returning the first one. To return all, use `return_all = TRUE`.Warning: Cannot convert object of class logical into a grob.Warning: Cannot convert object of class logical into a grob.
# Add panel letters A-D at fixed positions in figure coordinates (0-1), then
# save the composite figure (3.5 x 5.5).
plt_fig <- cowplot::ggdraw(plt_fig)+cowplot::draw_plot_label(c("A","B","C","D"), x=c(0,0,0.25,0.52), y=c(1,rep(0.52,3)))
plt_fig

ggsave(here("figures/4_Network_analysis.pdf"), plot=plt_fig, height = 5.5, width = 3.5)
---
title: "Co-diagnosis network analysis"
output: 
  html_notebook:
    toc: true
    toc_float: true
---

```{r, message = F}
library(here)
source(here("utils/data_processing.R"))
source(here("utils/figures.R"))
```

# Introduction

- The goal is to model the diagnoses generated for each set of criteria by how frequently they co-occur within a single 10-point differential diagnosis iteration.
- The rationale is that an ideal set of criteria will tend to result in a relatively similar set of diagnoses in each iteration and thus a small set of highly co-occurring diagnoses.
- Conversely, a poor set of criteria will generate highly variable diagnoses in each iteration and rates of co-occurrence will be comparatively lower.

# Import and process data

```{r, message = F}
# Raw (non-ICD-converted) responses for each model; read_model() is a project
# helper from the sourced utils/ scripts (not shown here).
df_gpt3.5 <- read_model("gpt-3.5-turbo-1106", icd = FALSE)
df_gpt4.0 <- read_model("gpt-4-turbo-preview", icd = FALSE)
df_claude3_haiku_t1.0 <- read_model("claude-3-haiku-20240307_t1-0", icd = FALSE)
df_claude3_opus_t1.0 <- read_model("claude-3-opus-20240229_t1-0", icd = FALSE)
df_gemini1.0_pro_t1.0 <- read_model("gemini-1.0-pro-002_t1-0", icd = FALSE)
df_gemini1.5_pro_t1.0 <- read_model("gemini-1.5-pro-001_t1-0", icd = FALSE)
```

```{r, message = F}
# ICD-converted responses for the same six models.
df_gpt3.5_icd <- read_model("gpt-3.5-turbo-1106", icd = TRUE)
df_gpt4.0_icd <- read_model("gpt-4-turbo-preview", icd = TRUE)
df_claude3_haiku_t1.0_icd <- read_model("claude-3-haiku-20240307_t1-0", icd = TRUE)
df_claude3_opus_t1.0_icd <- read_model("claude-3-opus-20240229_t1-0", icd = TRUE)
df_gemini1.0_pro_t1.0_icd <- read_model("gemini-1.0-pro-002_t1-0", icd = TRUE)
df_gemini1.5_pro_t1.0_icd <- read_model("gemini-1.5-pro-001_t1-0", icd = TRUE)
```

```{r}
# Tally each pair of co-occurring diagnoses within each criteria.
# Original responses (TRUE spelled out: `T` is a reassignable variable).
df_gpt3.5_codiag <- create_codiagnosis_df(df_gpt3.5, remove_singletons = TRUE)
df_gpt4.0_codiag <- create_codiagnosis_df(df_gpt4.0, remove_singletons = TRUE)
df_claude3_haiku_t1.0_codiag <- create_codiagnosis_df(df_claude3_haiku_t1.0, remove_singletons = TRUE)
df_claude3_opus_t1.0_codiag <- create_codiagnosis_df(df_claude3_opus_t1.0, remove_singletons = TRUE)
df_gemini1.0_pro_t1.0_codiag <- create_codiagnosis_df(df_gemini1.0_pro_t1.0, remove_singletons = TRUE)
df_gemini1.5_pro_t1.0_codiag <- create_codiagnosis_df(df_gemini1.5_pro_t1.0, remove_singletons = TRUE)
```


```{r}
# ICD converted responses (TRUE spelled out: `T` is a reassignable variable).
df_gpt3.5_icd_codiag <- create_codiagnosis_df(df_gpt3.5_icd, remove_singletons = TRUE)
df_gpt4.0_icd_codiag <- create_codiagnosis_df(df_gpt4.0_icd, remove_singletons = TRUE)
df_claude3_haiku_t1.0_icd_codiag <- create_codiagnosis_df(df_claude3_haiku_t1.0_icd, remove_singletons = TRUE)
df_claude3_opus_t1.0_icd_codiag <- create_codiagnosis_df(df_claude3_opus_t1.0_icd, remove_singletons = TRUE)
df_gemini1.0_pro_t1.0_icd_codiag <- create_codiagnosis_df(df_gemini1.0_pro_t1.0_icd, remove_singletons = TRUE)
df_gemini1.5_pro_t1.0_icd_codiag <- create_codiagnosis_df(df_gemini1.5_pro_t1.0_icd, remove_singletons = TRUE)
```


```{r}
# Preview the GPT-4 co-diagnosis table (original responses).
df_gpt4.0_codiag
```

# Graph visualization

## Exploring layouts 

- To determine which layout gives the clearest visualization of nodes and edges

```{r, fig.width=8, fig.height=6, warning=F, message=F}
# Selecting a layout
# Plot the GPT-4 graph once per candidate layout, reseeding before each plot so
# that differences come from the layout algorithm rather than the RNG state.
top_n <-  200
seed <- 1234

layouts <- c("fr", "dh", "kk", "stress", "graphopt")

graph_top_gpt4 <- make_codiagnosis_graph(df_gpt4.0_codiag, n_diagnoses = top_n)

for (l in layouts){
  set.seed(seed)
  plt <- centrality_graph(graph_top_gpt4, layout = l)
  plt <- plt + ggtitle("GPT4", subtitle = sprintf("Layout %s", l))
  # Explicit print() is needed for plots created inside a loop.
  print(plt)
}

```
## Individual model data

### Original responses

```{r, fig.width=8, fig.height=6, warning=F, message=F}
# Defaults for the per-model network plots below.
top_n <- 100
seed <- 321
graph_layout <- "stress"

# Plot the co-diagnosis network of one model's co-diagnosis table.
#
# data: co-diagnosis data frame (e.g. from create_codiagnosis_df()).
# n_diagnoses, layout, rng_seed: default to the globals top_n, graph_layout and
#   seed, looked up at call time exactly as before, but can now be overridden
#   per call instead of by mutating globals.
# The RNG is reseeded before building the graph so the plot is reproducible.
codiag_graph_wrapper <- function(data, n_diagnoses = top_n, layout = graph_layout, rng_seed = seed) {
  set.seed(rng_seed)
  graph <- make_codiagnosis_graph(data, n_diagnoses = n_diagnoses)
  centrality_graph(graph, layout = layout)
}

# One network per model (original responses), limited to top_n diagnoses.
codiag_graph_wrapper(df_gpt3.5_codiag)+ ggtitle("ChatGPT 3.5", subtitle = sprintf("Top %s", top_n))
codiag_graph_wrapper(df_gpt4.0_codiag)+ ggtitle("ChatGPT 4.0", subtitle = sprintf("Top %s", top_n))
codiag_graph_wrapper(df_claude3_haiku_t1.0_codiag)+ ggtitle("Claude 3 Haiku", subtitle = sprintf("Top %s", top_n))
codiag_graph_wrapper(df_claude3_opus_t1.0_codiag)+ ggtitle("Claude 3 Opus", subtitle = sprintf("Top %s", top_n))
codiag_graph_wrapper(df_gemini1.0_pro_t1.0_codiag)+ ggtitle("Gemini 1.0 Pro", subtitle = sprintf("Top %s", top_n))
codiag_graph_wrapper(df_gemini1.5_pro_t1.0_codiag)+ ggtitle("Gemini 1.5 Pro", subtitle = sprintf("Top %s", top_n))
```
### ICD converted responses

```{r, fig.width=8, fig.height=6, warning=F, message=F}
# One network per model (ICD-converted responses), limited to top_n diagnoses.
codiag_graph_wrapper(df_gpt3.5_icd_codiag)+ ggtitle("ChatGPT 3.5 ICD", subtitle = sprintf("Top %s", top_n))
codiag_graph_wrapper(df_gpt4.0_icd_codiag)+ ggtitle("ChatGPT 4.0 ICD", subtitle = sprintf("Top %s", top_n))
codiag_graph_wrapper(df_claude3_haiku_t1.0_icd_codiag)+ ggtitle("Claude 3 Haiku ICD", subtitle = sprintf("Top %s", top_n))
codiag_graph_wrapper(df_claude3_opus_t1.0_icd_codiag)+ ggtitle("Claude 3 Opus ICD", subtitle = sprintf("Top %s", top_n))
codiag_graph_wrapper(df_gemini1.0_pro_t1.0_icd_codiag)+ ggtitle("Gemini 1.0 Pro ICD", subtitle = sprintf("Top %s", top_n))
codiag_graph_wrapper(df_gemini1.5_pro_t1.0_icd_codiag)+ ggtitle("Gemini 1.5 Pro ICD", subtitle = sprintf("Top %s", top_n))
```

## Combined model data

### Original responses

```{r, fig.width=8, fig.height=6, warning=F, message=F}
# Combined network of all six models (original responses), threshold_method = "individual".
set.seed(seed)
multi_make_codiagnosis_graph(threshold_method = "individual", top_n = 100, layout="stress",
                             df_gpt3.5_codiag, df_gpt4.0_codiag, 
                             df_claude3_haiku_t1.0_codiag, df_claude3_opus_t1.0_codiag,
                             df_gemini1.0_pro_t1.0_codiag, df_gemini1.5_pro_t1.0_codiag)
```

```{r, fig.width=8, fig.height=6, warning=F, message=F}
# Combined network of all six models (original responses), threshold_method = "average".
# Reseed like the sibling chunks so this plot is reproducible when run on its own.
set.seed(seed)
multi_make_codiagnosis_graph(threshold_method = "average", top_n = 100, layout = "stress",
                             df_gpt3.5_codiag, df_gpt4.0_codiag,
                             df_claude3_haiku_t1.0_codiag, df_claude3_opus_t1.0_codiag,
                             df_gemini1.0_pro_t1.0_codiag, df_gemini1.5_pro_t1.0_codiag)
```
### ICD converted responses

```{r, fig.width=8, fig.height=6, warning=F, message=F}
# Combined network of all six models (ICD-converted), threshold_method = "individual".
set.seed(seed)
multi_make_codiagnosis_graph(threshold_method = "individual", top_n = 100, layout="stress",
                             df_gpt3.5_icd_codiag, df_gpt4.0_icd_codiag, 
                             df_claude3_haiku_t1.0_icd_codiag, df_claude3_opus_t1.0_icd_codiag,
                             df_gemini1.0_pro_t1.0_icd_codiag, df_gemini1.5_pro_t1.0_icd_codiag)
```


```{r, fig.width=8, fig.height=6, warning=F, message=F}
# Combined network of all six models (ICD-converted), threshold_method = "average".
# Reseed like the sibling chunks so this plot is reproducible when run on its own.
set.seed(seed)
multi_make_codiagnosis_graph(threshold_method = "average", top_n = 100, layout = "stress",
                             df_gpt3.5_icd_codiag, df_gpt4.0_icd_codiag,
                             df_claude3_haiku_t1.0_icd_codiag, df_claude3_opus_t1.0_icd_codiag,
                             df_gemini1.0_pro_t1.0_icd_codiag, df_gemini1.5_pro_t1.0_icd_codiag)
```

# Evaluate edge density

- Edge density represents the total number of edges in a graph relative to the total number of **possible** edges in the graph
  - When all possible edges are present, edge density = 1
  - When no edges are present, edge density = 0
- Dense co-occurrence networks represent criteria that generate highly reproducible diagnoses; a sparse co-occurrence network indicates a high degree of variability.

## Combined model data

### Original responses

```{r, fig.width=4, fig.height=3.5}
# Edge density per criteria and model for the original responses.
multi_edge_density_plot(
  df_gpt3.5_codiag,
  df_gpt4.0_codiag,
  df_claude3_haiku_t1.0_codiag,
  df_claude3_opus_t1.0_codiag,
  df_gemini1.0_pro_t1.0_codiag,
  df_gemini1.5_pro_t1.0_codiag
)  

```

### ICD converted responses

```{r, fig.width=4, fig.height=3.5}
# Edge density per criteria and model for the ICD-converted responses.
plt_edge_icd <- multi_edge_density_plot(
  df_gpt3.5_icd_codiag,
  df_gpt4.0_icd_codiag,
  df_claude3_haiku_t1.0_icd_codiag,
  df_claude3_opus_t1.0_icd_codiag,
  df_gemini1.0_pro_t1.0_icd_codiag,
  df_gemini1.5_pro_t1.0_icd_codiag
)  
plt_edge_icd
# Mean edge density per criteria, pooled over models.
plt_edge_icd$data %>% summarise(mean(edge_density), .by="criteria")
# p-values from the plot's comparison layer (project helper, not shown).
extract_ggpubr_pvalues(plt_edge_icd)
```



# Plot centrality
```{r}
# Compute node centrality separately within each criteria's sub-graph.
#
# g: a tidygraph tbl_graph whose edges carry a `criteria` column.
# centrality_fun: name of (or a function object for) a tidygraph centrality
#   function, e.g. "centrality_eigen". It is called with no arguments inside a
#   node-context mutate(), so it must work with its defaults.
# Returns a data.frame with one row per criteria (used as row names) and one
# column per node name. A diagnosis absent from a criteria's sub-graph gets 0.
calculate_subgraph_centrality <- function(g, centrality_fun = "centrality_eigen") {
  # match.fun() resolves a function name and also accepts a function directly.
  centrality_fn <- match.fun(centrality_fun)
  criteria_levels <- g %>% activate(edges) %>% pull(criteria) %>% unique()

  data.frame(criteria = criteria_levels) %>%
    mutate(sub_graphs = map(criteria, function(crit) {
      # Qualified explicitly: tibble also exports an as_data_frame() that would
      # return the active tbl_graph table instead of the named edge list.
      igraph::as_data_frame(g, what = "edges") %>%
        filter(criteria == crit) %>%
        as_tbl_graph(directed = FALSE) %>%
        activate(nodes) %>%
        mutate(centrality = centrality_fn()) %>%
        data.frame()
    })) %>%
    unnest(sub_graphs) %>%
    pivot_wider(names_from = "name", values_from = "centrality", values_fill = 0) %>%
    column_to_rownames("criteria")
}

# Worked example: per-criteria centrality of the GPT-4 graph (original
# responses), using n_diagnoses = top_n as set in the chunk above (100).
graph_all_gpt4 <- make_codiagnosis_graph(df_gpt4.0_codiag, n_diagnoses = top_n)
df_centr_gpt4 <- calculate_subgraph_centrality(graph_all_gpt4)
```

```{r}
# Cosine similarity between criteria, based on their node-centrality profiles.
#
# data: criteria x diagnosis data.frame with criteria as row names, e.g. the
#   output of calculate_subgraph_centrality().
# Returns a long data.frame with columns V1, V2 (criteria labels after
# format_criteria()) and cosine.
centrality_similarity <- function(data) {
  # format_criteria() operates on a column, so move row names out and back.
  labelled <- data %>%
    rownames_to_column("criteria") %>%
    format_criteria() %>%
    column_to_rownames("criteria")

  # lsa::cosine() compares columns, so criteria must become the columns.
  sim_matrix <- lsa::cosine(t(as.matrix(labelled)))

  sim_matrix %>%
    as.data.frame() %>%
    rownames_to_column("V1") %>%
    pivot_longer(-V1, names_to = "V2", values_to = "cosine")
}

centrality_similarity(df_centr_gpt4)
```
```{r}
# One-model pipeline: co-diagnosis graph -> per-criteria centrality ->
# pairwise criteria cosine similarity (long format).
# n_diag is forwarded to make_codiagnosis_graph() as n_diagnoses.
centrality_wrapper <- function(data, n_diag = NULL) {
  graph <- make_codiagnosis_graph(data, n_diagnoses = n_diag)
  centrality <- calculate_subgraph_centrality(graph)
  centrality_similarity(centrality)
}

# Criteria-by-criteria cosine similarity averaged over the six models'
# ICD-converted data, reshaped into a square matrix (rows V1, columns V2).
# listN() is a project helper; presumably it names list elements after the
# argument expressions, which enframe() turns into the `model` column.
average_cosine_matrix <- listN(
  df_gpt3.5_icd_codiag,
  df_gpt4.0_icd_codiag,
  df_claude3_haiku_t1.0_icd_codiag,
  df_claude3_opus_t1.0_icd_codiag,
  df_gemini1.0_pro_t1.0_icd_codiag,
  df_gemini1.5_pro_t1.0_icd_codiag
) %>%
  enframe(name = "model", value = "data") %>%
  # One long V1/V2/cosine table per model.
  mutate(data = map(data, centrality_wrapper)) %>%
  unnest(data) %>%
  # Mean over models for each criteria pair.
  dplyr::summarise(cosine = mean(cosine), .by = c("V1", "V2")) %>%
  pivot_wider(names_from = "V2", values_from = "cosine") %>%
  column_to_rownames("V1") %>%
  as.matrix()

average_cosine_matrix


```



```{r, fig.width=5.5, fig.height=3.5}
# Draw a ComplexHeatmap of a numeric matrix with a three-point colour ramp.
#
# Arguments:
#   data               numeric matrix to plot.
#   plot_title         optional title drawn above the heatmap columns.
#   legend_title       optional legend title (passed to Heatmap() as `name`).
#   color_scale        three colours for (min, midpoint, max); defaults to
#                      hcl.colors(3, "Earth") if symmetric, else viridis(3).
#   midpoint           value mapped to the middle colour; defaults to the
#                      centre of the colour range.
#   symmetric          TRUE: range is [-max(|data|), max(|data|)];
#                      FALSE: range is [min(data), max(data)].
#   matrix_title_size  font size of plot_title.
#   matrix_names_size  font size of row/column names.
#   legend_title_size, legend_label_size  legend font sizes.
#   dendrograms        show row/column dendrograms sized by dendrogram_weight.
#   legend_orientation "vertical" (legend on the right) or "horizontal" (bottom).
#   legend_length      optional legend height (vertical) or width (horizontal).
#   grid_lines         draw a black border around every cell.
# Returns the result of draw(), which also renders the heatmap.
custom_heatmap <- function(data,
                           plot_title = NULL,
                           legend_title = NULL,
                           color_scale = NULL,
                           midpoint = NULL,
                           symmetric = TRUE,
                           matrix_title_size = 10,
                           matrix_names_size = 8,
                           legend_title_size = 10,
                           legend_label_size = 8,
                           dendrograms = TRUE,
                           dendrogram_weight = unit(10, "mm"),
                           legend_orientation = "vertical",
                           legend_length = NULL,
                           grid_lines = FALSE) {
  # Determine the colour range (scalar if/else rather than ifelse()).
  if (symmetric) {
    scale_max <- max(abs(data))
    scale_min <- -scale_max
  } else {
    scale_max <- max(data)
    scale_min <- min(data)
  }
  if (is.null(midpoint)) {
    midpoint <- scale_min + (scale_max - scale_min) / 2
  }

  # Default colour scales depend on whether the range is symmetric around 0.
  if (is.null(color_scale)) {
    color_scale <- if (symmetric) hcl.colors(3, "Earth") else viridis::viridis(3)
  }
  color_function <- circlize::colorRamp2(c(scale_min, midpoint, scale_max), color_scale)

  # Legend parameters
  legend_params <- list(
    "title_gp" = grid::gpar(fontsize = legend_title_size, fontface = "bold"),
    "labels_gp" = grid::gpar(fontsize = legend_label_size),
    "direction" = legend_orientation
  )

  # Heatmap parameters
  heatmap_arguments <- list(
    "matrix" = data,
    "col" = color_function,
    "row_names_gp" = grid::gpar(fontsize = matrix_names_size),
    "column_names_gp" = grid::gpar(fontsize = matrix_names_size),
    "show_column_dend" = dendrograms,
    "show_row_dend" = dendrograms
  )
  if (!is.null(plot_title)) {
    # plot_title / matrix_title_size were previously accepted but ignored.
    heatmap_arguments[["column_title"]] <- plot_title
    heatmap_arguments[["column_title_gp"]] <- grid::gpar(fontsize = matrix_title_size)
  }
  if (!is.null(legend_title)) {
    heatmap_arguments[["name"]] <- legend_title
  }
  if (grid_lines) {
    heatmap_arguments[["rect_gp"]] <- grid::gpar(col = "black", lwd = 1)
  }
  if (dendrograms) {
    heatmap_arguments[["column_dend_height"]] <- dendrogram_weight
    heatmap_arguments[["row_dend_width"]] <- dendrogram_weight
  }

  # Legend placement and length follow its orientation.
  legend_side <- if (legend_orientation == "vertical") "right" else "bottom"
  if (!is.null(legend_length)) {
    if (legend_orientation == "vertical") {
      legend_params[["legend_height"]] <- legend_length
    } else if (legend_orientation == "horizontal") {
      legend_params[["legend_width"]] <- legend_length
    }
  }

  # Function call
  heatmap_arguments[["heatmap_legend_param"]] <- legend_params
  ht <- do.call(Heatmap, heatmap_arguments)
  draw(ht, heatmap_legend_side = legend_side, align_heatmap_legend = "global_center")
}

custom_heatmap(average_cosine_matrix, symmetric = FALSE, legend_title = "Cosine similarity", grid_lines = TRUE, dendrograms = FALSE, legend_orientation = "horizontal", legend_length = unit(10, "cm"))
```
```{r}
# Per-model (not averaged) criteria-pair cosine similarities for the
# ICD-converted data: one row per model x (V1, V2) pair. This repeats the
# centrality_wrapper() computation used for average_cosine_matrix.
df_comp <- listN(
  df_gpt3.5_icd_codiag,
  df_gpt4.0_icd_codiag,
  df_claude3_haiku_t1.0_icd_codiag,
  df_claude3_opus_t1.0_icd_codiag,
  df_gemini1.0_pro_t1.0_icd_codiag,
  df_gemini1.5_pro_t1.0_icd_codiag
) %>%
  enframe(name = "model", value = "data") %>%
  mutate(data = map(data, centrality_wrapper)) %>%
  unnest(data)
```

```{r, fig.width=3.5, fig.height=4}
# Plot, per model, the cosine similarity of the two SLE criteria and of the two
# MCAS criteria, with a p-value label from ggpubr::stat_compare_means()
# (wilcox.test, grouped by comparison).
#
# df: long data.frame with columns model, V1, V2, cosine (e.g. df_comp).
# pt_size: point size; p_size: text size of the p-value label.
# Returns a ggplot object.
cosine_similarity_compare <- function(df, pt_size = 1, p_size = 3) {
  # One ordering of each within-disease criteria pair, as "V1_V2" after unite().
  kept_pairs <- c("MCAS - Consortium_MCAS - Alternative", "SLE - EULAR-ACR_SLE - SLICC")

  plot_df <- unite(df, comp, V1, V2)
  plot_df <- filter(plot_df, comp %in% kept_pairs)
  # Turn "A_B" into a two-line "A\nvs. B" axis label.
  plot_df <- mutate(plot_df, comp = gsub("_", "\nvs. ", comp))
  plot_df <- format_models(plot_df)

  base_plot <- ggplot(plot_df, aes(x = comp, y = cosine, color = model)) +
    geom_point(size = pt_size, position = position_dodge(width = 0.75)) +
    theme_bw()

  base_plot +
    ggpubr::stat_compare_means(aes(group = comp), method = "wilcox.test", label = "p", vjust = 0.75, show.legend = FALSE, size = p_size) +
    theme(axis.text.x = element_text(angle = 90, hjust = 1)) +
    labs(x = NULL, y = "Cosine similarity") +
    scale_color_brewer(palette = "Dark2") +
    labs(color = "")
}
cosine_similarity_compare(df_comp)
```

```{r}
# Cosine similarity between every (model, criteria) combination, using each
# combination's vector of per-diagnosis centralities. Diagnoses absent from a
# combination's sub-graph count as 0.
individual_cosine_matrix <- combine_data_frames(
  df_gpt3.5_icd_codiag,
  df_gpt4.0_icd_codiag,
  df_claude3_haiku_t1.0_icd_codiag,
  df_claude3_opus_t1.0_icd_codiag,
  df_gemini1.0_pro_t1.0_icd_codiag,
  df_gemini1.5_pro_t1.0_icd_codiag,
  additional_function = function(x) {
    x %>%
      make_codiagnosis_graph() %>%
      calculate_subgraph_centrality() %>%
      rownames_to_column("criteria") %>%
      pivot_longer(-criteria, names_to = "diagnosis", values_to = "centrality")
  }
) %>%
  # Row labels become "<source data frame>-<criteria>" (original_df is
  # presumably added by combine_data_frames(); not visible here).
  unite(criteria, original_df, criteria, sep = "-") %>%
  pivot_wider(names_from = "diagnosis", values_from = "centrality", values_fill = 0) %>%
  column_to_rownames("criteria") %>%
  as.matrix() %>%
  t() %>%
  # lsa::cosine() already returns a square matrix labelled by the input
  # columns, so the former long/wide round-trip back to a matrix was a no-op.
  lsa::cosine()

individual_cosine_matrix[1:5, 1:5]
```

```{r, fig.width=5.5, fig.height=3.5}
# Heatmap of the per-model / per-criteria cosine similarity matrix
# (viridis colour scale, not treated as symmetric).
alt_heatmap <- model_criteria_heatmap(
  individual_cosine_matrix,
  color_scale = viridis::viridis(3),
  title = " ",
  metric = "Cosine\nsimilarity",
  symmetric = FALSE,
  font_size = 8
)

alt_heatmap
```
```{r, fig.width=6, fig.height=4}
# Capture the ComplexHeatmap drawing as a grob so cowplot can lay it out
# like any other plot.
plt_alt <- cowplot::plot_grid(
  plotlist = list(
    grid::grid.grabExpr(ComplexHeatmap::draw(alt_heatmap))
  )
)
plt_alt
```
```{r}
ggsave(filename = here("figures", "4_alt_heatmap.pdf"), plot = plt_alt, width = 6, height = 4)
```

# Final plot

### Version 1
```{r, fig.width=3.5, fig.height=5.5}
# n_diagnoses_bar <- 10
# n_diagnoses_abundance <- 50
# n_diagnoses_cumulative <- 50

# Font sizes (pt) and legend padding shared by the figure panels
title_size <- 9
label_size <- 6
legend_x_pad <- 0
legend_y_pad <- 0

# Theme overrides for the ggplot panels, wrapped in a list so they can be
# added to a plot with `+`.
apply_text_formatting <- list(
  theme(
    axis.text = element_text(size = label_size),
    axis.title = element_text(size = title_size),
    legend.text = element_text(size = label_size, margin = margin(r = 1)),
    strip.text = element_text(size = label_size + 1),
    legend.key.height = unit(0.4, "cm"),
    legend.key.width = unit(0.4, "cm"),
    # legend.key = element_rect(size =  margin(0,0,0,0)),
    # `linewidth` replaces `size`, which has been deprecated for element_rect()
    # since ggplot2 3.4.0. This theme already needs ggplot2 >= 3.5.0 for
    # `legend.key.spacing.y`.
    legend.box.background = element_rect(color = "black", linewidth = 1),
    legend.margin = margin(
      t = legend_y_pad,
      r = legend_x_pad + 2,
      b = legend_y_pad,
      l = legend_x_pad
    ),
    # Negative spacing pulls legend keys closer together vertically
    legend.key.spacing.y = unit(-1.5, "pt"),
    legend.box.spacing = unit(1.0, "pt"),
    # Angled x tick labels; other properties are inherited from `axis.text`
    axis.text.x = element_text(angle = 45, hjust = 1)
  )
)


# Uniform padding (pt) around facet strip labels on both axes
strip_margin <- 1
strip_formatting <- local({
  strip_pad <- margin(strip_margin, strip_margin, strip_margin, strip_margin)
  list(theme(
    strip.text.x = element_text(margin = strip_pad),
    strip.text.y = element_text(margin = strip_pad)
    # strip.background = element_rect(margin = margin(t = strip_margin, r = strip_margin, b = strip_margin, l = strip_margin))
  ))
})

# Fixed seed before building the graphs (presumably so the layout is
# reproducible between knits).
set.seed(1234)

# NOTE(review): despite the `_icd` name, this network is built from the
# original (non-ICD-converted) co-diagnosis tables, whereas the edge-density
# and heatmap panels use the `_icd_codiag` tables -- confirm this is intended.
plt_graph_icd <- multi_make_codiagnosis_graph(
  # network construction
  threshold_method = "average", top_n = 100, layout = "stress",
  # one co-diagnosis table per model
  df_gpt3.5_codiag, df_gpt4.0_codiag,
  df_claude3_haiku_t1.0_codiag, df_claude3_opus_t1.0_codiag,
  df_gemini1.0_pro_t1.0_codiag, df_gemini1.5_pro_t1.0_codiag,
  # node / edge styling
  point_size = 1.25, border_size = 0.25,
  edge_width = 0.5, edge_alpha = 0.5,
  # text and legend styling
  label_text_size = 9, tick_text_size = 6,
  highlight_stroke_multiplier = 3,
  legend_height = unit(25, "pt"), legend_width = unit(10, "pt")
)

# Edge-density panel for the ICD-converted co-diagnosis networks, with a
# two-column horizontal legend below the plot.
plt_edge_icd <- multi_edge_density_plot(
  df_gpt3.5_icd_codiag,
  df_gpt4.0_icd_codiag,
  df_claude3_haiku_t1.0_icd_codiag,
  df_claude3_opus_t1.0_icd_codiag,
  df_gemini1.0_pro_t1.0_icd_codiag,
  df_gemini1.5_pro_t1.0_icd_codiag
)
plt_edge_icd <- plt_edge_icd +
  apply_text_formatting +
  theme(legend.direction = "horizontal", legend.position = "bottom") +
  guides(color = guide_legend(ncol = 2, override.aes = list(size = 1))) +
  scale_y_continuous(expand = expansion(mult = 0.1))

# Heatmap of the average cosine-similarity matrix, with dendrograms, grid
# lines and a horizontal colour legend.
plt_heatmap_icd <- custom_heatmap(
  average_cosine_matrix,
  symmetric = FALSE,
  legend_title = "Cosine similarity",
  grid_lines = TRUE,
  dendrograms = TRUE,
  legend_orientation = "horizontal",
  legend_length = unit(2, "cm"),
  matrix_names_size = 6,
  legend_title_size = 7.5,
  legend_label_size = 6,
  dendrogram_weight = unit(2, "mm")
)


# Assemble the final figure (Version 1):
#   row 1: top spacer
#   row 2: network graph (label A), indented by a narrow left spacer
#   row 3: thin spacer
#   row 4: edge-density plot (B) next to the cosine-similarity heatmap (C)
plt_fig <- cowplot::plot_grid(
  NULL,
  cowplot::plot_grid(
    NULL, plt_graph_icd,
    nrow = 1, rel_widths = c(0.05, 0.9)
  ),
  NULL,
  cowplot::plot_grid(
    cowplot::plot_grid(NULL, plt_edge_icd, rel_heights = c(0.1, 1), ncol = 1),
    NULL,
    cowplot::plot_grid(
      NULL,
      # ComplexHeatmap output is not a ggplot; capture its drawing as a grob
      grid::grid.grabExpr(
        ComplexHeatmap::draw(plt_heatmap_icd, heatmap_legend_side = "bottom")
      ),
      NULL,
      ncol = 1,
      rel_heights = c(0.1, 1, 0.1)
    ),
    nrow = 1,
    rel_widths = c(0.5, 0.1, 0.6),
    labels = c("B", "", "C"),
    # `align` takes "none"/"h"/"v"/"hv" and `axis` a combination of
    # "l"/"r"/"t"/"b". These were previously swapped, which effectively
    # disabled alignment.
    align = "h", axis = "bt"
  ),
  ncol = 1,
  rel_heights = c(0.05, 0.95, 0.01, 1),
  labels = "A"
)

plt_fig
```

### Version 2
```{r, fig.width=3.5, fig.height=5.5}
# n_diagnoses_bar <- 10
# n_diagnoses_abundance <- 50
# n_diagnoses_cumulative <- 50

# Font sizes (pt) and legend padding shared by the figure panels
title_size <- 9
label_size <- 6
legend_x_pad <- 0
legend_y_pad <- 0

# Theme overrides for the ggplot panels, wrapped in a list so they can be
# added to a plot with `+`.
apply_text_formatting <- list(
  theme(
    axis.text = element_text(size = label_size),
    axis.title = element_text(size = title_size),
    legend.text = element_text(size = label_size, margin = margin(r = 1)),
    strip.text = element_text(size = label_size + 1),
    legend.key.height = unit(0.4, "cm"),
    legend.key.width = unit(0.4, "cm"),
    # legend.key = element_rect(size =  margin(0,0,0,0)),
    # `linewidth` replaces `size`, which has been deprecated for element_rect()
    # since ggplot2 3.4.0. This theme already needs ggplot2 >= 3.5.0 for
    # `legend.key.spacing.y`.
    legend.box.background = element_rect(color = "black", linewidth = 1),
    legend.margin = margin(
      t = legend_y_pad,
      r = legend_x_pad + 2,
      b = legend_y_pad,
      l = legend_x_pad
    ),
    # Negative spacing pulls legend keys closer together vertically
    legend.key.spacing.y = unit(-1.5, "pt"),
    legend.box.spacing = unit(1.0, "pt"),
    # Vertical x tick labels, centred on the tick; other properties are
    # inherited from `axis.text`
    axis.text.x = element_text(angle = 90, vjust = 0.5)
  )
)


# Uniform padding (pt) around facet strip labels on both axes
strip_margin <- 1
strip_formatting <- local({
  strip_pad <- margin(strip_margin, strip_margin, strip_margin, strip_margin)
  list(theme(
    strip.text.x = element_text(margin = strip_pad),
    strip.text.y = element_text(margin = strip_pad)
    # strip.background = element_rect(margin = margin(t = strip_margin, r = strip_margin, b = strip_margin, l = strip_margin))
  ))
})

# Fixed seed before building the graphs (presumably so the layout is
# reproducible between knits).
set.seed(1234)

# NOTE(review): despite the `_icd` name, this network is built from the
# original (non-ICD-converted) co-diagnosis tables, whereas the edge-density
# and heatmap panels use the `_icd_codiag` tables -- confirm this is intended.
plt_graph_icd <- multi_make_codiagnosis_graph(
  # network construction
  threshold_method = "average", top_n = 100, layout = "stress",
  # one co-diagnosis table per model
  df_gpt3.5_codiag, df_gpt4.0_codiag,
  df_claude3_haiku_t1.0_codiag, df_claude3_opus_t1.0_codiag,
  df_gemini1.0_pro_t1.0_codiag, df_gemini1.5_pro_t1.0_codiag,
  # node / edge styling
  point_size = 1.25, border_size = 0.25,
  edge_width = 0.5, edge_alpha = 0.5,
  # text and legend styling
  label_text_size = 9, tick_text_size = 6,
  highlight_stroke_multiplier = 3,
  legend_height = unit(25, "pt"), legend_width = unit(10, "pt")
)

# Edge-density panel for the ICD-converted co-diagnosis networks. Its legend
# is laid out in two columns; the figure below places it separately.
plt_edge_icd <- multi_edge_density_plot(
  df_gpt3.5_icd_codiag,
  df_gpt4.0_icd_codiag,
  df_claude3_haiku_t1.0_icd_codiag,
  df_claude3_opus_t1.0_icd_codiag,
  df_gemini1.0_pro_t1.0_icd_codiag,
  df_gemini1.5_pro_t1.0_icd_codiag
)
plt_edge_icd <- plt_edge_icd +
  apply_text_formatting +
  # theme(legend.position = "bottom", legend.direction = "horizontal") +
  guides(color = guide_legend(ncol = 2, override.aes = list(size = 1))) +
  scale_y_continuous(expand = expansion(mult = 0.1))

# Cosine-similarity comparison panel, styled like the other figure panels
plt_cosine <- cosine_similarity_compare(df = df_comp, p_size = 2)
plt_cosine <- plt_cosine + apply_text_formatting

# Heatmap of the average cosine-similarity matrix, with dendrograms, grid
# lines and a horizontal colour legend.
plt_heatmap_icd <- custom_heatmap(
  average_cosine_matrix,
  symmetric = FALSE,
  legend_title = "Average\ncosine similarity",
  grid_lines = TRUE,
  dendrograms = TRUE,
  legend_orientation = "horizontal",
  legend_length = unit(2.5, "cm"),
  matrix_names_size = 6,
  legend_title_size = 7.5,
  legend_label_size = 6,
  dendrogram_weight = unit(2.5, "mm")
)

# NOTE(review): `pd` is not referenced in this chunk; left in place in case
# other code relies on it.
pd <- -1

# Assemble the final figure (Version 2); panel letters are drawn afterwards.
#   #1 top spacer
#   #2 network graph, indented by a narrow left spacer
#   #3 spacer
#   #4 bottom row: [edge density | cosine comparison] above a model legend,
#      next to the average cosine-similarity heatmap
# All cowplot calls are namespace-qualified so the chunk does not depend on
# cowplot being attached; NULL is used as the empty placeholder throughout.
plt_fig <- cowplot::plot_grid(
  #1
  NULL,
  #2
  cowplot::plot_grid(
    NULL, plt_graph_icd,
    nrow = 1, rel_widths = c(0.05, 0.9)
  ),
  #3
  NULL,
  #4
  cowplot::plot_grid(
    #4.1 left column
    cowplot::plot_grid(
      #4.1.1 edge-density and cosine panels side by side, legends removed
      cowplot::plot_grid(
        plt_edge_icd + theme(legend.position = "none"),
        plt_cosine + theme(legend.position = "none") + ylim(0.4, 1),
        # rel_widths = c(10,9),
        nrow = 1,
        align = "h", axis = "bt"
      ),
      #4.1.2 spacer
      NULL,
      #4.1.3 legend taken from the edge-density plot, shifted right by a spacer
      cowplot::plot_grid(
        NULL, cowplot::get_legend(plt_edge_icd),
        nrow = 1, rel_widths = c(0.15, 1)
      ),
      #4.1.4 spacer
      NULL,
      ncol = 1,
      rel_heights = c(1, 0.04, 0.1, 0.1)
    ),
    #4.2 right column: heatmap with spacers above and below
    cowplot::plot_grid(
      NULL,
      # ComplexHeatmap output is not a ggplot; capture its drawing as a grob
      grid::grid.grabExpr(
        ComplexHeatmap::draw(
          plt_heatmap_icd,
          heatmap_legend_side = "bottom",
          padding = unit(c(0, 0, 0, -1), "mm")
        )
      ),
      NULL,
      ncol = 1,
      rel_heights = c(0.1, 1, 0.1)
    ),
    nrow = 1,
    rel_widths = c(0.6, 0.5),
    align = "h", axis = "bt"
  ),
  ncol = 1,
  rel_heights = c(0.05, 0.95, 0.05, 1)
)

# Overlay panel letters: A at the top-left, B/C/D along the bottom row
plt_fig <- cowplot::ggdraw(plt_fig) +
  cowplot::draw_plot_label(
    c("A", "B", "C", "D"),
    x = c(0, 0, 0.25, 0.52),
    y = c(1, rep(0.52, 3))
  )
plt_fig
```


```{r, eval=F}
ggsave(filename = here("figures", "4_Network_analysis.pdf"), plot = plt_fig, width = 3.5, height = 5.5)
```

